import torch
from torch.nn import functional as F


def amigo(params, hparams, val_data_list, args, out_f, reg_f, z):
    """AmIGO-style hypergradient estimate using ``args.hessian_q`` z-steps.

    Let g be the inner objective ``reg_f(params, hparams, per_sample_CE)`` and f
    the mean cross-entropy.  This function:

    1. computes b = grad_params f on batch 0 (detached, treated as a constant);
    2. runs ``args.hessian_q`` iterations of
       ``z <- z - args.eta * (H z - b)``, where H z is the Hessian-vector product
       of g w.r.t. ``params`` on batch 1, i.e. gradient descent on the quadratic
       0.5 * z^T H z - z^T b;
    3. returns -d/d(hparams) <grad_params g, z> evaluated on batch 2.

    Args:
        params: Inner variable.  Every ``torch.autograd.grad`` result is indexed
            with ``[0]`` and flattened, so this is presumably a single tensor.
        hparams: Outer variable the returned gradient is taken w.r.t.  It must
            require grad and must be used by ``reg_f``; otherwise autograd raises.
        val_data_list: ``(data_list, labels_list)``.  Indices 0, 1 and 2 are the
            batches for grad f, the Hessian-vector products and the mixed
            derivative, respectively.
        args: Must provide ``hessian_q`` (number of z steps) and ``eta`` (step
            size).
        out_f: ``out_f(data, params)`` returns scores fed to ``F.cross_entropy``.
        reg_f: ``reg_f(params, hparams, per_sample_loss)`` returns the scalar
            inner objective.
        z: Initial flat vector with as many elements as ``params``.  The
            updated z is not returned.

    Returns:
        The negated mixed-derivative/vector product, shaped like ``hparams``.
    """
    data_list, labels_list = val_data_list

    # b = grad_params f on batch 0.  gradient_fy does not build a graph, so b is
    # a constant in the z iteration below.
    output = out_f(data_list[0], params)
    Fy_gradient = gradient_fy(args, labels_list[0], params, data_list[0], output)
    Fy_gradient = torch.reshape(Fy_gradient, [-1])

    # grad_params g on batch 1, built once with create_graph=True so that it can
    # be differentiated repeatedly for Hessian-vector products.  Going through
    # the shared gradient_gy helper keeps the loss definition (reduction mode)
    # identical to the mixed-derivative term computed below.
    output = out_f(data_list[1], params)
    grad_gy = gradient_gy(args, labels_list[1], params, data_list[1], hparams, output, reg_f)
    grad_gy = torch.reshape(grad_gy, [-1])

    for _ in range(args.hessian_q):
        # H z = d/dparams <grad_params g, z>.  retain_graph keeps grad_gy's graph
        # alive for the next iteration.
        HVP = torch.autograd.grad(torch.matmul(grad_gy, z), params, retain_graph=True)[0]
        HVP = torch.reshape(HVP, [-1])
        z = z - args.eta * (HVP - Fy_gradient)

    # Mixed second derivative on batch 2: d/dhparams <grad_params g, z>.
    output = out_f(data_list[2], params)
    Gy_gradient = gradient_gy(args, labels_list[2], params, data_list[2], hparams, output, reg_f)
    Gy_gradient = torch.reshape(Gy_gradient, [-1])
    Gyx_gradient = torch.autograd.grad(torch.matmul(Gy_gradient, z), hparams, retain_graph=True)[0]
    outer_update = -Gyx_gradient

    return outer_update



def LH_w_H(params, hparams, val_data_list, args, out_f, reg_f, z):
    """One z-step followed by a hypergradient estimate; returns the new z too.

    Performs a single update ``z <- z - args.eta * (H z - b)``, where
    b = grad_params f (mean cross-entropy, batch 0, detached) and H z is the
    Hessian-vector product of the inner objective g w.r.t. ``params`` (batch 1).
    It then returns -d/d(hparams) <grad_params g, z_new> on batch 2.  Unlike
    ``amigo``, ``args.hessian_q`` is not read; exactly one step is taken.

    Args:
        params: Inner variable.  Every autograd result is indexed with ``[0]``
            and flattened, so this is presumably a single tensor.
        hparams: Outer variable.  It must require grad and be used by ``reg_f``.
        val_data_list: ``(data_list, labels_list)``, with batches at indices
            0, 1 and 2.
        args: Must provide ``eta`` (step size).
        out_f: ``out_f(data, params)`` returns scores fed to ``F.cross_entropy``.
        reg_f: ``reg_f(params, hparams, per_sample_loss)`` returns the scalar
            inner objective.
        z: Current flat vector with as many elements as ``params``.

    Returns:
        ``(z_new, outer_update)``.  The updated z is returned, presumably so
        the caller can carry it to the next call (TODO confirm with caller).
    """
    data_list, labels_list = val_data_list

    # b = grad_params f on batch 0 (no graph kept -> constant).
    output = out_f(data_list[0], params)
    Fy_gradient = gradient_fy(args, labels_list[0], params, data_list[0], output)
    Fy_gradient = torch.reshape(Fy_gradient, [-1])

    # grad_params g on batch 1 with create_graph=True so it can be differentiated
    # again.  Using the shared gradient_gy helper keeps the loss reduction in
    # sync with the mixed-derivative term below.
    output = out_f(data_list[1], params)
    grad_gy = gradient_gy(args, labels_list[1], params, data_list[1], hparams, output, reg_f)
    grad_gy = torch.reshape(grad_gy, [-1])

    # H z = d/dparams <grad_params g, z>.
    HVP = torch.autograd.grad(torch.matmul(grad_gy, z), params, retain_graph=True)[0]
    HVP = torch.reshape(HVP, [-1])
    z = z - args.eta * (HVP - Fy_gradient)

    # Mixed second derivative on batch 2, using the updated z.
    output = out_f(data_list[2], params)
    Gy_gradient = gradient_gy(args, labels_list[2], params, data_list[2], hparams, output, reg_f)
    Gy_gradient = torch.reshape(Gy_gradient, [-1])
    Gyx_gradient = torch.autograd.grad(torch.matmul(Gy_gradient, z), hparams, retain_graph=True)[0]
    outer_update = -Gyx_gradient

    return z, outer_update


def LH(params, hparams, val_data_list, args, out_f, reg_f, z):
    """Hypergradient estimate for a fixed ``z`` (z is not updated).

    Returns -d/d(hparams) <grad_params g, z>, where grad_params g is produced by
    ``gradient_gy`` on batch index 2 of ``val_data_list``.  Batches 0 and 1 are
    not used here.

    Args:
        params: Inner variable the inner gradient is taken w.r.t.
        hparams: Outer variable.  It must require grad and be used by ``reg_f``.
        val_data_list: ``(data_list, labels_list)``.
        args: Forwarded to ``gradient_gy``.
        out_f: ``out_f(data, params)`` returns scores for the loss.
        reg_f: Inner-objective builder forwarded to ``gradient_gy``.
        z: Flat vector with as many elements as ``params``.

    Returns:
        The negated mixed-derivative/vector product, shaped like ``hparams``.
    """
    batches, targets = val_data_list
    x_batch = batches[2]

    logits = out_f(x_batch, params)
    inner_grad = gradient_gy(args, targets[2], params, x_batch, hparams, logits, reg_f)
    flat_inner_grad = inner_grad.reshape(-1)

    # Contract with z first, then differentiate the scalar w.r.t. hparams.
    projected = flat_inner_grad.matmul(z)
    (mixed, *_unused) = torch.autograd.grad(projected, hparams, retain_graph=True)
    return -mixed


def gradient_fy(args, labels, params, data, output):
    """Gradient of the mean cross-entropy of ``output`` vs ``labels`` w.r.t. ``params``.

    ``args`` and ``data`` are accepted only for signature symmetry with
    ``gradient_gy``; neither is read.  No graph is kept (``create_graph`` is
    left at its default), so the returned tensor is detached and the graph
    behind ``output`` is released by this call.

    Returns:
        The first element of the ``torch.autograd.grad`` result, i.e. the
        gradient for ``params`` (or for its first tensor, if it is a sequence).
    """
    mean_ce = F.cross_entropy(output, labels)
    (first_grad, *_rest) = torch.autograd.grad(mean_ce, params)
    return first_grad

def gradient_gy(args, labels_cp, params, data, hparams, output, reg_f):
    """Differentiable gradient of the inner objective w.r.t. ``params``.

    The per-sample cross-entropy (``reduction='none'``) is passed to
    ``reg_f(params, hparams, loss)``, whose result is differentiated w.r.t.
    ``params`` with ``create_graph=True``.  The returned gradient can therefore
    be differentiated again (Hessian-vector or mixed second derivatives).
    ``reg_f`` is assumed to return a scalar; TODO confirm against its
    implementations.  ``args`` and ``data`` are not read.

    Returns:
        The first element of the ``torch.autograd.grad`` result.
    """
    # MNIST data hyper-cleaning experiments: keep per-sample losses (presumably
    # so reg_f can weight samples by hparams).
    # NewsGroup l2-reg experiments use the mean loss instead:
    #     per_sample = F.cross_entropy(output, labels_cp)
    per_sample = F.cross_entropy(output, labels_cp, reduction='none')
    inner_objective = reg_f(params, hparams, per_sample)
    grads = torch.autograd.grad(inner_objective, params, create_graph=True)
    return grads[0]

